#!/usr/bin/env python3

"""
Given this linked list: 1->2->3->4->5
For k = 2, you should return: 2->1->4->3->5
For k = 3, you should return: 3->2->1->4->5
"""

class ListNode:
    """A node of a singly linked list.

    Attributes are plain and public: ``val`` holds the payload and ``next``
    refers to the following node, or ``None`` at the end of the list.
    """

    def __init__(self, val):
        self.val = val
        self.next = None

    def show(self):
        """Print the values from this node to the end of the list.

        Every value is followed by a single space and no trailing newline is
        written. The list is walked iteratively so that its length is not
        limited by the interpreter's recursion depth.
        """
        node = self
        while node is not None:
            print(str(node.val), end=' ')
            node = node.next

class Solution:
    """Reverse a singly linked list in consecutive groups of ``k`` nodes.

    Nodes only need a mutable ``next`` attribute that is ``None`` at the end
    of the list. Both methods are iterative, so neither the list length nor
    ``k`` is bounded by the interpreter's recursion limit.
    """

    def __init__(self):
        # Outputs of the most recent successful reverse_n() call:
        # next_head -- first node after the reversed group (or None)
        # top_head  -- new first node of the reversed group
        self.next_head = None
        self.top_head = None

    def reverse_k_group(self, head, k):
        """Reverse the list ``k`` nodes at a time, in place.

        A trailing group with fewer than ``k`` nodes keeps its order.

        Args:
            head: First node of the list, or ``None`` for an empty list.
            k: Group size. Values below 1 leave the list untouched.

        Returns:
            The head of the regrouped list, or ``None`` if ``head`` is
            ``None``.
        """
        if head is None:
            return None

        new_head = head
        prev_tail = None  # tail of the previously reversed group
        group_head = head
        while group_head is not None:
            group_tail = self.reverse_n(group_head, k)
            if group_tail is None:
                # Fewer than k nodes remain (or k < 1). reverse_n did not
                # touch them and prev_tail already links to group_head.
                break
            if prev_tail is None:
                new_head = self.top_head
            else:
                prev_tail.next = self.top_head
            prev_tail = group_tail
            group_head = self.next_head
        return new_head

    def reverse_n(self, head, n):
        """Reverse the first ``n`` nodes starting at ``head``, in place.

        On success ``self.top_head`` is set to the new first node of the
        group and ``self.next_head`` to the node that followed the group.
        The group's new tail is linked to ``self.next_head``, so the list
        stays acyclic even before the caller re-links anything.

        Args:
            head: First node of the group to reverse.
            n: Number of nodes to reverse.

        Returns:
            ``head`` (now the last node of the reversed group) on success.
            ``None`` if ``head`` is ``None``, ``n`` is below 1, or fewer
            than ``n`` nodes are available; in that case nothing is
            modified.
        """
        if head is None or n < 1:
            return None

        # First pass: find the n-th node so a short group is left untouched.
        last = head
        for _ in range(n - 1):
            last = last.next
            if last is None:
                return None

        # Second pass: flip the links. Seeding `prev` with the remainder
        # makes the old head end up pointing at the node after the group.
        remainder = last.next
        prev = remainder
        node = head
        for _ in range(n):
            following = node.next
            node.next = prev
            prev = node
            node = following

        self.next_head = remainder
        self.top_head = prev
        return head

if __name__ == '__main__':
    # Demo: build a linked list from the sample values, regroup it with
    # k = 2 and print the resulting values on one line.
    sample_values = [1]
    head = tail = None
    for value in sample_values:
        fresh = ListNode(value)
        if head is None:
            head = fresh
        else:
            tail.next = fresh
        # The newest node is always the tail of the list built so far.
        tail = fresh

    regrouped = Solution().reverse_k_group(head, 2)
    regrouped.show()
    print()
